Bayesian Estimation of Differential Equations

Much of scientific work involves mathematically modeling the surrounding reality, and such models often take the form of dynamical systems. These complex dynamical systems are commonly modeled with differential equations, whose models often contain parameters that cannot be measured directly. The familiar “forward problem” of simulation consists of solving the differential equations for a given set of parameters; the corresponding “inverse problem”, known as parameter estimation, is the process of using data to determine these model parameters. Bayesian inference provides a robust approach to parameter estimation with quantified uncertainty.

using Turing
using DifferentialEquations

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots

using LinearAlgebra

# Set a seed for reproducibility.
using Random
Random.seed!(14);

The Lotka-Volterra Model

The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The populations change through time according to the pair of equations

\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\ \frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t) \end{aligned} \]

where \(x(t)\) and \(y(t)\) denote the populations of prey and predator at time \(t\), respectively, and \(\alpha, \beta, \gamma, \delta\) are positive parameters.

We implement the Lotka-Volterra model and simulate it with parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\).

# Define Lotka-Volterra model.
function lotka_volterra(du, u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator

    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Plot simulation.
plot(solve(prob, Tsit5()))

We generate noisy observations to use for the parameter estimation tasks in this tutorial. With the saveat argument we specify that the solution is stored only at intervals of 0.1 time units. To make the example more realistic, we add normally distributed random noise to the simulation.

sol = solve(prob, Tsit5(); saveat=0.1)
odedata = Array(sol) + 0.8 * randn(size(Array(sol)))

# Plot simulation and noisy observations.
plot(sol; alpha=0.3)
scatter!(sol.t, odedata'; color=[1 2], label="")

Alternatively, we can use real-world data from Hudson’s Bay Company records (a Stan implementation with slightly different priors can be found here: https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html).

Direct Handling of Bayesian Estimation with Turing

Previously, functions in Turing and DifferentialEquations were not inter-composable, so Bayesian inference of differential equations needed to be handled by a separate package called DiffEqBayes.jl (note that DiffEqBayes also works with CmdStan.jl, Turing.jl, DynamicHMC.jl, and ApproxBayes.jl - see the DiffEqBayes docs for more info).

Nowadays, however, Turing and DifferentialEquations are completely composable and we can just simulate differential equations inside a Turing @model. Therefore, we write the Lotka-Volterra parameter estimation problem using the Turing @model macro as below:

@model function fitlv(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate Lotka-Volterra model. 
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)

    # Observations.
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end

    return nothing
end

model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.025
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.025
Chains MCMC chain (1000×17×3 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 3
Samples per chain = 1000
Wall duration     = 51.13 seconds
Compute duration  = 48.68 seconds
parameters        = σ, α, β, γ, δ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           σ    0.8109    0.0408    0.0009   1983.5023   1667.1376    1.0028   ⋯
           α    1.5762    0.0510    0.0020    673.8528    823.3316    1.0091   ⋯
           β    1.0110    0.0456    0.0016    832.9322   1124.0233    1.0067   ⋯
           γ    2.8172    0.1324    0.0050    716.1514    887.4438    1.0085   ⋯
           δ    0.9122    0.0468    0.0018    692.8983    800.3465    1.0091   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           σ    0.7354    0.7821    0.8098    0.8381    0.8917
           α    1.4805    1.5419    1.5744    1.6087    1.6776
           β    0.9262    0.9793    1.0084    1.0415    1.1048
           γ    2.5655    2.7291    2.8142    2.9040    3.0862
           δ    0.8257    0.8807    0.9111    0.9424    1.0088

The estimated parameters are close to the parameter values the observations were generated with. We can also check visually that the chains have converged.

plot(chain)
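
Beyond the visual check, convergence can also be assessed numerically. A minimal sketch, assuming the ess_rhat function from MCMCChains (which Turing re-exports): rhat values close to 1 and large effective sample sizes indicate well-mixed chains.

# Numerical convergence diagnostics: rhat close to 1 and large bulk/tail
# effective sample sizes indicate that the chains have mixed well.
ess_rhat(chain)

# Posterior means for a quick comparison with the data-generating values
# α = 1.5, β = 1, γ = 3, δ = 1 (and noise σ = 0.8).
mean(chain)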

Data retrodiction

In Bayesian analysis it is often useful to retrodict the data, i.e. generate simulated data using samples from the posterior distribution, and compare it to the original data (see, for instance, Section 3.3.2 on model checking in McElreath’s book “Statistical Rethinking”). Here, we solve the ODE for 300 randomly picked posterior samples in the chain. We plot the ensemble of solutions to check whether the solution resembles the data. The 300 retrodicted time courses from the posterior are plotted in gray, the noisy observations are shown as blue and red dots, and the green and purple lines are the ODE solution that was used to generate the data.

plot(; legend=false)
posterior_samples = sample(chain[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])

We can see that, even though we added quite a bit of noise to the data, the posterior distribution reproduces the “true” ODE solution quite accurately.

Lotka-Volterra model without prey data

One can also perform parameter inference for a Lotka-Volterra model with incomplete data. For instance, suppose we have observations of the predators only, not of the prey. That is, we fit the model only to the \(y\) variable of the system, without providing any data for \(x\):

@model function fitlv2(data::AbstractVector, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate Lotka-Volterra model but save only the second state of the system (predators).
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, save_idxs=2)

    # Observations of the predators.
    data ~ MvNormal(predicted.u, σ^2 * I)

    return nothing
end

model2 = fitlv2(odedata[2, :], prob)

# Sample 3 independent chains.
chain2 = sample(model2, NUTS(0.45), MCMCSerial(), 5000, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Info: Found initial step size
└   ϵ = 0.025
Chains MCMC chain (5000×17×3 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 3
Samples per chain = 5000
Wall duration     = 37.21 seconds
Compute duration  = 36.63 seconds
parameters        = σ, α, β, γ, δ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           σ    0.8533    0.0599    0.0083    52.9891    91.1634    1.2370     ⋯
           α    1.6688    0.1881    0.0246    55.6299   167.1523    1.1589     ⋯
           β    1.1023    0.1391    0.0178    58.6327   280.0447    1.1908     ⋯
           γ    2.8109    0.2568    0.0342    58.2688    83.0185    1.0950     ⋯
           δ    0.8080    0.2026    0.0277    57.1216   289.7429    1.1749     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           σ    0.7340    0.8172    0.8543    0.8952    0.9693
           α    1.3929    1.5313    1.6400    1.7833    2.1098
           β    0.8939    0.9947    1.0896    1.1819    1.4240
           γ    2.3031    2.6421    2.8085    2.9723    3.3202
           δ    0.4225    0.6607    0.7988    0.9577    1.1640

Again we inspect the trajectories of 300 randomly selected posterior samples.

plot(; legend=false)
posterior_samples = sample(chain2[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])

Note that here the observations of the prey (blue dots) were not used in the parameter estimation! Yet, the model can predict the values of \(x\) relatively accurately, albeit with a wider distribution of solutions, reflecting the greater uncertainty in the prediction of the \(x\) values.

Inference of Delay Differential Equations

Here we show an example of inference with another type of differential equation: a Delay Differential Equation (DDE). DDEs are differential equations where the derivatives depend on values of the solution at an earlier point in time. This is useful for modeling delayed effects, such as the incubation time of a virus.

Here is a delayed version of the Lotka-Volterra system:

\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= \alpha x(t-\tau) - \beta y(t) x(t),\\ \frac{\mathrm{d}y}{\mathrm{d}t} &= - \gamma y(t) + \delta x(t) y(t), \end{aligned} \]

where \(\tau\) is a (positive) delay and \(x(t-\tau)\) is the variable \(x\) at an earlier time point \(t - \tau\).

The initial-value problem of the delayed system can be implemented as a DDEProblem. As described in the DDE example, here the function h is the history function that can be used to obtain a state at an earlier time point. Again we use parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\); the delay is fixed to \(\tau = 1\). Moreover, we assume \(x(t) = 1\) for \(t < 0\).

function delay_lotka_volterra(du, u, h, p, t)
    # Model parameters.
    α, β, γ, δ = p

    # Current state.
    x, y = u
    # Evaluate differential equations
    du[1] = α * h(p, t - 1; idxs=1) - β * x * y
    du[2] = -γ * y + δ * x * y

    return nothing
end

# Define initial-value problem.
p = (1.5, 1.0, 3.0, 1.0)
u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
h(p, t; idxs::Int) = 1.0
prob_dde = DDEProblem(delay_lotka_volterra, u0, h, tspan, p);

We generate observations by adding normally distributed noise to the results of our simulations.

sol_dde = solve(prob_dde; saveat=0.1)
ddedata = Array(sol_dde) + 0.5 * randn(size(sol_dde))

# Plot simulation and noisy observations.
plot(sol_dde)
scatter!(sol_dde.t, ddedata'; color=[1 2], label="")

Now we define the Turing model for the Lotka-Volterra model with delay and sample 3 independent chains.

@model function fitlv_dde(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate Lotka-Volterra model.
    p = [α, β, γ, δ]
    predicted = solve(prob, MethodOfSteps(Tsit5()); p=p, saveat=0.1)

    # Observations.
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end
end

model_dde = fitlv_dde(ddedata, prob_dde)

# Sample 3 independent chains.
chain_dde = sample(model_dde, NUTS(), MCMCSerial(), 300, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Info: Found initial step size
└   ϵ = 0.00625
Chains MCMC chain (300×17×3 Array{Float64, 3}):

Iterations        = 151:1:450
Number of chains  = 3
Samples per chain = 300
Wall duration     = 13.96 seconds
Compute duration  = 13.58 seconds
parameters        = σ, α, β, γ, δ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           σ    0.5294    0.0259    0.0010   614.0110   553.9518    0.9993     ⋯
           α    1.4983    0.0750    0.0045   275.2620   317.8696    1.0075     ⋯
           β    1.0352    0.0605    0.0035   310.9576   336.6450    1.0091     ⋯
           γ    3.0133    0.1624    0.0093   304.0169   395.5705    1.0059     ⋯
           δ    1.0162    0.0572    0.0034   278.5049   335.7740    1.0075     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           σ    0.4835    0.5101    0.5275    0.5469    0.5854
           α    1.3678    1.4477    1.4890    1.5413    1.6626
           β    0.9281    0.9944    1.0302    1.0725    1.1662
           γ    2.6811    2.9022    3.0161    3.1220    3.3351
           δ    0.8944    0.9785    1.0161    1.0560    1.1303

plot(chain_dde)

Finally, plot trajectories of 300 randomly selected samples from the posterior. Again, the dots indicate our observations, the colored lines are the “true” simulations without noise, and the gray lines are trajectories from the posterior samples.

plot(; legend=false)
posterior_samples = sample(chain_dde[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob_dde, MethodOfSteps(Tsit5()); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol_dde; color=[1 2], linewidth=1)
scatter!(sol_dde.t, ddedata'; color=[1 2])

The fit is pretty good even though the data was quite noisy to start.

Scaling to Large Models: Adjoint Sensitivities

DifferentialEquations.jl’s efficiency for large stiff models has been shown in multiple benchmarks. To learn more about how to optimize solving performance for stiff problems you can take a look at the docs.

Sensitivity analysis, i.e., automatic differentiation (AD) of the solver, is provided by the DiffEq suite. The model sensitivities are the derivatives of the solution with respect to the parameters: the local sensitivity of the solution to a parameter measures how much the solution changes in response to a change in that parameter. Sensitivity analysis provides a cheap way to calculate the gradient of the solution, which can be used in parameter estimation and other optimization tasks.
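
As a concrete illustration, local forward sensitivities can be computed directly with ODEForwardSensitivityProblem from SciMLSensitivity. The following is a minimal sketch of that interface, separate from the inference below (the variable names here are chosen just for the example):

using SciMLSensitivity

# Augment the Lotka-Volterra ODE with its forward sensitivity equations.
prob_sens = ODEForwardSensitivityProblem(
    lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0]
)
sol_sens = solve(prob_sens, Tsit5(); saveat=0.1)

# Split the augmented solution into the states and the local sensitivities
# du/dp, returned as one matrix per parameter.
states, sensitivities = extract_local_sensitivities(sol_sens)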

The AD ecosystem in Julia allows you to switch between forward mode, reverse mode, source-to-source, and other choices of AD and have it work with any Julia code. To make use of this within SciML, the high-level interactions in solve automatically plug into those AD systems, allowing advanced sensitivity analysis (derivative calculation) methods to be selected.

More theoretical details on these methods can be found at: https://docs.sciml.ai/latest/extras/sensitivity_math/.

While these sensitivity analysis methods may seem complicated, using them is dead simple. Here is a version of the Lotka-Volterra model using adjoint sensitivities.

All we have to do is switch the AD backend to one of the adjoint-compatible backends (ReverseDiff, Tracker, or Zygote)! Notice that on this model adjoints are slower. This is because adjoints have a higher overhead on small parameter models and therefore we suggest using these methods only for models with around 100 parameters or more. For more details, see https://arxiv.org/abs/1812.01892.
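
For instance, with ReverseDiff the call might look as follows (a sketch; AutoReverseDiff is one of the AD types Turing re-exports, and the ReverseDiff package must be loaded):

using ReverseDiff

# Sample a single chain using reverse-mode AD via ReverseDiff.
sample(model, NUTS(; adtype=AutoReverseDiff()), 1000; progress=false)

Below we use Zygote instead: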

using Zygote, SciMLSensitivity

# Sample a single chain with 1000 samples using Zygote.
sample(model, NUTS(;adtype=AutoZygote()), 1000; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.05
Chains MCMC chain (1000×17×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 470.51 seconds
Compute duration  = 470.51 seconds
parameters        = σ, α, β, γ, δ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           σ    2.1295    0.1146    0.0056   402.8176   323.4393    1.0019     ⋯
           α    1.2215    0.1077    0.0219    31.0583    27.5086    1.0670     ⋯
           β    0.8775    0.1235    0.0070   291.5847   193.7163    1.0059     ⋯
           γ    1.1262    0.1423    0.0295    29.9477    34.6927    1.0715     ⋯
           δ    0.5417    0.0778    0.0145    41.7733    29.7712    1.0501     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           σ    1.9234    2.0530    2.1221    2.2034    2.3652
           α    0.9490    1.1680    1.2394    1.2976    1.3841
           β    0.6617    0.7927    0.8703    0.9416    1.1559
           γ    1.0039    1.0303    1.0806    1.1588    1.5897
           δ    0.4402    0.4889    0.5258    0.5714    0.7494

If desired, we can control the sensitivity analysis method that is used by providing the sensealg keyword argument to solve.
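
For illustration, an explicit choice of an adjoint method inside the model could look like the following (a sketch only; InterpolatingAdjoint and ReverseDiffVJP are provided by SciMLSensitivity, and this particular combination is just one possible choice):

predicted = solve(prob; p=p, saveat=0.1,
                  sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP(true)))

Here, however, we will not choose a sensealg and let the default choice be used: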

@model function fitlv_sensealg(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate Lotka-Volterra model; no sensealg is specified, so the default
    # sensitivity algorithm is used.
    p = [α, β, γ, δ]
    predicted = solve(prob; p=p, saveat=0.1)

    # Observations.
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end

    return nothing
end;

model_sensealg = fitlv_sensealg(odedata, prob)

# Sample a single chain with 1000 samples using Zygote.
sample(model_sensealg, NUTS(;adtype=AutoZygote()), 1000; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.2
Chains MCMC chain (1000×17×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 442.14 seconds
Compute duration  = 442.14 seconds
parameters        = σ, α, β, γ, δ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           σ    2.2258    0.1090    0.0040   770.5695   604.8066    1.0007     ⋯
           α    2.1371    0.1302    0.0101   185.8600   151.5148    1.0014     ⋯
           β    1.8659    0.1098    0.0042   478.9176   395.7409    1.0007     ⋯
           γ    3.5237    0.2585    0.0188   189.9886   146.7963    1.0008     ⋯
           δ    1.6630    0.1374    0.0082   279.1140   417.1714    1.0031     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           σ    2.0313    2.1473    2.2205    2.2939    2.4610
           α    1.9181    2.0344    2.1253    2.2165    2.4150
           β    1.5858    1.8101    1.8903    1.9516    1.9955
           γ    2.9790    3.3445    3.5339    3.7234    3.9594
           δ    1.4157    1.5616    1.6627    1.7597    1.9224

For more examples of adjoint usage on large parameter models, consult the DiffEqFlux documentation.
